{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 🤗 Hugging Face x Promptify 🚀"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "from promptify import HubModel, Prompter"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Table of content:\n",
    "- **A.** First example: binary classification\n",
    "- **B.** Another example: multiclass classification\n",
    "- **C.** Play with options and parameters\n",
    "- **D.** Production-ready API using Inference Endpoints"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# A. First example: binary classification"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Initialize HubModel and Prompter\n",
    "\n",
    "`HubModel` checks that the model exists on the 🤗 Hub and that it is a text2text-generation model.\n",
    "\n",
    "You can browse https://huggingface.co/models to find a model that suits your needs. The default is [`google/flan-t5-xl`](https://huggingface.co/google/flan-t5-xl), a popular text2text-generation model fine-tuned on more than 1000 additional tasks covering 60 languages."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = HubModel(model_id_or_url=\"google/flan-t5-xl\")\n",
    "prompter = Prompter(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Load samples from a binary classification dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Got 9 binary classification examples.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "with open(\"data/binary.json\", \"r\") as f:\n",
    "    binary_examples = json.load(f)\n",
    "print(\"Got\", len(binary_examples), \"binary classification examples.\")\n",
    "\n",
    "prompt_examples = []\n",
    "for sample in binary_examples[:2]:\n",
    "    prompt_examples.append((sample['text'], sample['labels']))\n",
    "print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is the first sample. It contains a `\"text\"` field and a `\"labels\"` field."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'text': 'Eight years the republicans denied obama’s picks. Breitbarters outrage is as phony as their fake president.',\n",
       " 'labels': 'negative',\n",
       " 'score': '',\n",
       " 'complexity': ''}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "binary_examples[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. (optional) Generate a prompt for binary classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Binary Classification System\n",
      "You are a highly intelligent and accurate Binary Classification system. You take Passage as input and classify that as either positive or negative Category. Your output format is only [{'C':Category}] form, no other form.\n",
      "\n",
      "Examples:\n",
      "\n",
      "Input: Eight years the republicans denied obama’s picks. Breitbarters outrage is as phony as their fake president.\n",
      "Output: [{'C': 'negative' }]\n",
      "\n",
      "Input: Except he’s the most successful president in our lifetimes. He’s undone most of the damage Obummer did and set America on the right path again.\n",
      "Output: [{'C': 'positive' }]\n",
      "\n",
      "Input: Amazing customer service.\n",
      "Output:\n"
     ]
    }
   ],
   "source": [
    "print(prompter.generate_prompt(\n",
    "    \"binary_classification.jinja\",\n",
    "    label_0=\"positive\",\n",
    "    label_1=\"negative\",\n",
    "    examples=prompt_examples,\n",
    "    text_input=\"Amazing customer service.\",\n",
    "    description=\"Binary Classification System\",\n",
    "))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Make predictions!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.1 Define your `predict` function\n",
    "\n",
    "- Takes a text as input\n",
    "- Generates a prompt (using the `\"binary_classification.jinja\"` template and some examples)\n",
    "- Sends a request to the HF Inference API\n",
    "- Searches for a label in the output\n",
    "\n",
    "Note that the post-processing step is left for you to define. For classification tasks, looking for the label in the output can be enough. For more complex tasks (like NER), the output needs more thorough post-processing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(text:str) -> str:\n",
    "    output = prompter.fit(\n",
    "        \"binary_classification.jinja\",\n",
    "        label_0=\"positive\",\n",
    "        label_1=\"negative\",\n",
    "        examples=prompt_examples,\n",
    "        text_input=text,\n",
    "    )\n",
    "    if \"positive\" in output:\n",
    "        return \"positive\"\n",
    "    elif \"negative\" in output:\n",
    "        return \"negative\"\n",
    "    return \"unknown\""
   ]
  },
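  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The substring check above is enough here, but the template asks the model for a `[{'C': Category}]` answer, so the output can also be parsed more strictly. The cell below is a minimal sketch, not part of Promptify: `extract_category` is a hypothetical helper, and it assumes the model output is available as a single string (the return type of `prompter.fit` is not shown in this notebook). It first tries to read the answer as a Python literal and falls back to a case-insensitive label search when the output is malformed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ast\n",
    "from typing import Iterable\n",
    "\n",
    "\n",
    "def extract_category(output: str, labels: Iterable[str]) -> str:\n",
    "    \"\"\"Return the label found in a [{'C': label}]-style model output (hypothetical helper).\"\"\"\n",
    "    labels = list(labels)\n",
    "    try:\n",
    "        # Strict path: parse the output as a Python literal and read the 'C' key.\n",
    "        parsed = ast.literal_eval(output.strip())\n",
    "        category = parsed[0][\"C\"]\n",
    "        if category in labels:\n",
    "            return category\n",
    "    except (ValueError, SyntaxError, LookupError, TypeError):\n",
    "        pass  # Malformed output: fall through to the lenient path.\n",
    "    # Lenient path: first label that appears anywhere in the text, ignoring case.\n",
    "    lowered = output.lower()\n",
    "    for label in labels:\n",
    "        if label.lower() in lowered:\n",
    "            return label\n",
    "    return \"unknown\"\n",
    "\n",
    "\n",
    "print(extract_category(\"[{'C': 'negative' }]\", [\"positive\", \"negative\"]))  # negative\n",
    "print(extract_category(\"I think it is Positive.\", [\"positive\", \"negative\"]))  # positive\n",
    "print(extract_category(\"no idea\", [\"positive\", \"negative\"]))  # unknown"
   ]
  },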
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.2 Test it!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Eight years the republicans denied obama’s picks. Breitbarters outrage is as phony as their fake president.\n",
      "Expected:    negative\n",
      "Prediction:  negative ✅\n",
      "\n",
      "\n",
      "Except he’s the most successful president in our lifetimes. He’s undone most of the damage Obummer did and set America on the right path again.\n",
      "Expected:    positive\n",
      "Prediction:  positive ✅\n",
      "\n",
      "\n",
      "So disappointed in wwe summerslam! I want to see john cena wins his 16th title\n",
      "Expected:    negative\n",
      "Prediction:  negative ✅\n",
      "\n",
      "\n",
      "Looking forward to going to Carrow Rd tonight. Last time we were there\\u002c Bale scored 2 and we were 3rd. Do not want extra time though\n",
      "Expected:    positive\n",
      "Prediction:  positive ✅\n",
      "\n",
      "\n",
      "It's a good day at work when you get to shake Jim Lehrer's hand. Thanks, @user Still kicking myself for being to shy to hug\n",
      "Expected:    positive\n",
      "Prediction:  positive ✅\n",
      "\n",
      "\n",
      "Trumpism likewise rests on a bed of racial resentment that was made knowingly and intentionally, long before Trump got into politics.\n",
      "Expected:    negative\n",
      "Prediction:  negative ✅\n",
      "\n",
      "\n",
      "i have been with petronas for years i feel that petronas has performed well and made a huge profit\n",
      "Expected:    positive\n",
      "Prediction:  positive ✅\n",
      "\n",
      "\n",
      "I've read many post on here saying Israel gives training to US law enforcement .\n",
      "Expected:    positive\n",
      "Prediction:  positive ✅\n",
      "\n",
      "\n",
      "And the sad thing is the white students at those schools will act like that too .\n",
      "Expected:    negative\n",
      "Prediction:  negative ✅\n"
     ]
    }
   ],
   "source": [
    "for item in binary_examples:\n",
    "    prediction = predict(item[\"text\"])\n",
    "    print(\"\\n\")\n",
    "    print(item[\"text\"])\n",
    "    print(\"Expected:   \", item[\"labels\"])\n",
    "    print(\"Prediction: \", prediction, \"✅\" if prediction == item[\"labels\"] else \"❌\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# B. Another example: multiclass classification"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Initialize HubModel and Prompter\n",
    "\n",
    "Already done in part A."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Load samples from a multiclass classification dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Got 10 multiclass classification examples.\n",
      "Labels are: {'surprise', 'worry', 'joy', 'hate', 'neutral', 'sadness'}\n"
     ]
    }
   ],
   "source": [
    "with open(\"data/multiclass.json\", \"r\") as f:\n",
    "    multiclass_examples = json.load(f)\n",
    "print(\"Got\", len(multiclass_examples), \"multiclass classification examples.\")\n",
    "labels = set(sample['category'] for sample in multiclass_examples)\n",
    "print(\"Labels are:\", labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. (optional) Generate a prompt for multiclass classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "You are a highly intelligent and accurate Multiclass Classification system. You take Passage as input and classify that as one of the following appropriate Categories:\n",
      "{'surprise', 'worry', 'joy', 'hate', 'neutral', 'sadness'}\n",
      "Your output format is only [{{'C': Appropriate Category from the list of provided Categories}}] form, no other form.\n",
      "\n",
      "Input: I ate Something I don't know what it is... Why do I keep Telling things about food\n",
      "Output:\n"
     ]
    }
   ],
   "source": [
    "print(\n",
    "    prompter.generate_prompt(\n",
    "        \"multiclass_classification.jinja\",\n",
    "        labels=labels,\n",
    "        text_input=multiclass_examples[0][\"text\"],\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Make predictions!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.1 Define your `predict` function\n",
    "\n",
    "- Takes a text as input\n",
    "- Generates a prompt (using the `\"multiclass_classification.jinja\"` template and your labels)\n",
    "- Sends a request to the HF Inference API\n",
    "- Searches for a label in the output\n",
    "\n",
    "Note that the post-processing step is left for you to define. For classification tasks, looking for the label in the output can be enough. For more complex tasks (like NER), the output needs more thorough post-processing."
   ]
  },
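  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(text: str) -> str:\n",
    "    # Build the prompt from the label set, query the model, then look for a label in the output.\n",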
    "    output = prompter.fit(\"multiclass_classification.jinja\", labels=labels, text_input=text)\n",
    "    for label in labels:\n",
    "        if label in output:\n",
    "            return label\n",
    "    return \"unknown\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.2 Test it!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "I ate Something I don't know what it is... Why do I keep Telling things about food\n",
      "Expected:    worry\n",
      "Prediction:  neutral ❌\n",
      "\n",
      "\n",
      "Here's to the start of a great adventure. Niners today, Alaska tomorrow.\n",
      "Expected:    joy\n",
      "Prediction:  neutral ❌\n",
      "\n",
      "\n",
      "It is so annoying when she starts typing on her computer in the middle of the night!\n",
      "Expected:    hate\n",
      "Prediction:  hate ✅\n",
      "\n",
      "\n",
      "Chocolate milk is so much better through a straw. I lack said straw\n",
      "Expected:    neutral\n",
      "Prediction:  neutral ✅\n",
      "\n",
      "\n",
      "I want to buy this great album but unfortunately i dont hav enuff funds  its &quot;long time noisy&quot;\n",
      "Expected:    sadness\n",
      "Prediction:  neutral ❌\n",
      "\n",
      "\n",
      "dont wanna work 11-830 tomorrow  but i get paid\n",
      "Expected:    sadness\n",
      "Prediction:  neutral ❌\n",
      "\n",
      "\n",
      "Oh no one minute too late! Oh well\n",
      "Expected:    worry\n",
      "Prediction:  neutral ❌\n",
      "\n",
      "\n",
      "2 days of this month left, and I only have 400MB left on my onpeak downloads.\n",
      "Expected:    surprise\n",
      "Prediction:  neutral ❌\n",
      "\n",
      "\n",
      "my last tweet didn't send  bad phone\n",
      "Expected:    neutral\n",
      "Prediction:  worry ❌\n",
      "\n",
      "\n",
      "I had a dream about a pretty pretty beach and there was no beach when I woke up\n",
      "Expected:    surprise\n",
      "Prediction:  surprise ✅\n"
     ]
    }
   ],
   "source": [
    "for item in multiclass_examples:\n",
    "    prediction = predict(item[\"text\"])\n",
    "    print(\"\\n\")\n",
    "    print(item[\"text\"])\n",
    "    print(\"Expected:   \", item[\"category\"])\n",
    "    print(\"Prediction: \", prediction, \"✅\" if prediction == item[\"category\"] else \"❌\")"
   ]
  },
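  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the run above, 3 of the 10 predictions match the expected label. To track that number while you change the prompt or the parameters, a small scoring helper can be useful. The cell below is only a sketch: `accuracy_report` is a hypothetical helper, not a Promptify function, and the two lists are copied by hand from the output above. In the notebook you would build them from `multiclass_examples` and `predict` instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "from typing import Sequence\n",
    "\n",
    "\n",
    "def accuracy_report(expected: Sequence[str], predicted: Sequence[str]) -> float:\n",
    "    \"\"\"Print per-label hit counts and return the overall accuracy (hypothetical helper).\"\"\"\n",
    "    if len(expected) != len(predicted):\n",
    "        raise ValueError(\"expected and predicted must have the same length\")\n",
    "    totals = Counter(expected)\n",
    "    hits = Counter(e for e, p in zip(expected, predicted) if e == p)\n",
    "    for label in sorted(totals):\n",
    "        print(f\"{label:10s} {hits[label]}/{totals[label]}\")\n",
    "    return sum(hits.values()) / len(expected) if expected else 0.0\n",
    "\n",
    "\n",
    "# Expected and predicted labels copied from the output of the previous cell.\n",
    "expected = [\"worry\", \"joy\", \"hate\", \"neutral\", \"sadness\", \"sadness\", \"worry\", \"surprise\", \"neutral\", \"surprise\"]\n",
    "predicted = [\"neutral\", \"neutral\", \"hate\", \"neutral\", \"neutral\", \"neutral\", \"neutral\", \"neutral\", \"worry\", \"surprise\"]\n",
    "print(\"Accuracy:\", accuracy_report(expected, predicted))  # Accuracy: 0.3"
   ]
  },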
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# C. Play with options and parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model parameters\n",
    "\n",
    "Your results might not be satisfactory on the first try. You can adjust the model parameters to try to get better output. Text generation models are based on the `transformers` library, specifically its \"text-generation\" pipeline. A detailed list of all parameters supported by this pipeline can be found on the [🤗 text generation doc page](https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task).\n",
    "\n",
    "Here is the list of parameters you can use, with their types and default values:\n",
    "```py\n",
    "top_k: Optional[int] = None\n",
    "top_p: Optional[float] = None\n",
    "temperature: float = 1.0\n",
    "repetition_penalty: Optional[float] = None\n",
    "max_new_tokens: Optional[int] = None\n",
    "max_time: Optional[float] = None\n",
    "num_return_sequences: int = 1\n",
    "do_sample: bool = True\n",
    "```\n",
    "\n",
    "For classification tasks, these parameters might not be of much interest. For tasks like summarization, however, you might want to look at parameters such as `max_new_tokens`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Options\n",
    "\n",
    "In addition to the pipeline parameters, you can set Inference API options. In particular:\n",
    "- `wait_for_model` (bool, defaults to `True`): Whether to wait for the model to be loaded in the Inference API. Popular models are often already loaded, but more specific models have to be warmed up before they can be used.\n",
    "- `use_cache` (bool, defaults to `True`): The Inference API has a cache layer that speeds up requests it has already seen. Most models can use those cached results as is because they are deterministic (the results will be the same anyway). However, if you use a non-deterministic model, you can set this parameter to `False` to bypass the cache and trigger a genuinely new query."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['9']"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# In this example, if the model had not already been loaded, an HTTPError (503 Service Unavailable)\n",
    "# would have been raised. Also, the output is not deterministic, so the result might change between reruns.\n",
    "\n",
    "model.run(\n",
    "    prompts=\"What is the sum of four and five?\",\n",
    "    wait_for_model=False,\n",
    "    use_cache=False,\n",
    ")"
   ]
  },
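  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you set `wait_for_model=False`, you may want to retry the call yourself while the model is loading. The cell below is a sketch of one way to do that with exponential backoff. `call_with_retry` is a hypothetical helper, not part of Promptify, and the `flaky` function is a stand-in for a model that is still loading. Which exception class the client raises on a 503 is an assumption you should check before narrowing `retry_on`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "from typing import Callable, Tuple, Type, TypeVar\n",
    "\n",
    "T = TypeVar(\"T\")\n",
    "\n",
    "\n",
    "def call_with_retry(\n",
    "    call: Callable[[], T],\n",
    "    retry_on: Tuple[Type[BaseException], ...] = (Exception,),\n",
    "    attempts: int = 5,\n",
    "    initial_delay: float = 1.0,\n",
    ") -> T:\n",
    "    \"\"\"Call `call()` until it succeeds, doubling the delay after each failure.\"\"\"\n",
    "    delay = initial_delay\n",
    "    for attempt in range(1, attempts + 1):\n",
    "        try:\n",
    "            return call()\n",
    "        except retry_on:\n",
    "            if attempt == attempts:\n",
    "                raise  # Out of attempts: surface the last error.\n",
    "            time.sleep(delay)\n",
    "            delay *= 2\n",
    "    raise ValueError(\"attempts must be at least 1\")\n",
    "\n",
    "\n",
    "# Stand-in for a model that fails twice before it is loaded (hypothetical).\n",
    "calls = {\"n\": 0}\n",
    "\n",
    "\n",
    "def flaky() -> str:\n",
    "    calls[\"n\"] += 1\n",
    "    if calls[\"n\"] < 3:\n",
    "        raise RuntimeError(\"503 Service Unavailable\")\n",
    "    return \"9\"\n",
    "\n",
    "\n",
    "print(call_with_retry(flaky, retry_on=(RuntimeError,), initial_delay=0.01))  # 9\n",
    "\n",
    "# In this notebook, the call could be wrapped like this (untested assumption):\n",
    "# call_with_retry(lambda: model.run(prompts=\"What is the sum of four and five?\", wait_for_model=False))"
   ]
  },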
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# D. Production-ready API using Inference Endpoints\n",
    "\n",
    "The 🤗 Inference API is a free-to-use tool for quickly trying a wide range of open-source models. It offers a generous free tier for experimentation but is not suitable for production use. For that, use 🤗 Inference Endpoints, a secure, production-ready product for easily deploying any Transformers, Sentence-Transformers, or Diffusion model hosted on the Hub on dedicated infrastructure managed by Hugging Face. To learn more and start your first Endpoint, check out the [documentation](https://huggingface.co/docs/inference-endpoints/index).\n",
    "\n",
    "Once your Inference Endpoint is deployed, you get a URL exposing an API to your model. This URL can be passed directly to `HubModel`. Switch from the free-to-use API to a production-ready solution with a single line of code!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = HubModel(\"https://endpoint-id.region.vendor.endpoints.huggingface.cloud\", api_key=\"hf_***\")\n",
    "\n",
    "# model.run(\"My text input\", ...)\n",
    "# Prompter(model).fit(...)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  },
  "vscode": {
   "interpreter": {
    "hash": "f6d884a85d6833415ff5d0644d8efa773af5069f6068d8bbcc53dcdf3529c2ab"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
